// @file datacu.hpp
// @brief Basic data structures (CUDA support)
// @author Andrea Vedaldi

/*
Copyright (C) 2015 Andrea Vedaldi.
All rights reserved.

This file is part of the VLFeat library and is made available under
the terms of the BSD license (see the COPYING file).
*/

#ifndef __vl__datacu__
#define __vl__datacu__

#ifndef ENABLE_GPU
#error "datacu.hpp cannot be compiled without GPU support"
#endif

#include "data.hpp"
#include <string>
#include <cuda.h>
#include <cublas_v2.h>
// Thread count used by the library for CUDA kernels (presumably the
// threads-per-block value of kernel launches — confirm at launch sites).
// Devices of compute capability >= 2.0 allow up to 1024 threads per block;
// earlier devices allow 512.
// NOTE(review): __CUDA_ARCH__ is only defined while nvcc compiles device
// code. In host code (and in non-CUDA compilers) the undefined identifier
// evaluates to 0 inside #if, so host-side code always sees 512 even when
// device code built for sm_20+ sees 1024. Confirm this host/device
// difference is intended.
#if __CUDA_ARCH__ >= 200
#define VL_CUDA_NUM_THREADS 1024
#else
#define VL_CUDA_NUM_THREADS 512
#endif

#ifdef ENABLE_CUDNN
#include <cudnn.h>
#endif

namespace vl {

  /// @brief Helper that keeps GPU-library state: the last CUDA error, the
  /// cuBLAS handle and last cuBLAS error, and (when built with
  /// ENABLE_CUDNN) the cuDNN handle, enable flag and last cuDNN error.
  ///
  /// Construction and destruction are protected; only vl::Context (a friend)
  /// can create, clear or destroy instances.
  ///
  /// NOTE(review): the class stores raw library handles and declares a
  /// destructor, but copying is not disabled. A copy would share the same
  /// handles (and could release them twice if the destructor frees them).
  /// Confirm that Context never copies a CudaHelper.
  class CudaHelper {
  public:
    // ---------------------------------------------------------------
    // CUDA errors
    // ---------------------------------------------------------------

    /// @return the last CUDA error code stored in this helper
    ///         (presumably recorded by catchCudaError()).
    cudaError_t getLastCudaError() const ;

    /// @return the message associated with the last stored CUDA error.
    std::string const& getLastCudaErrorMessage() const ;

    /// Records the current CUDA error state and maps it to a vl::Error.
    /// @param description optional context text for the error message;
    ///        may be NULL.
    /// NOTE(review): how the CUDA error is queried and how the message is
    /// composed is defined in the implementation file — confirm there.
    vl::Error catchCudaError(char const* description = NULL) ;

    // ---------------------------------------------------------------
    // cuBLAS support
    // ---------------------------------------------------------------

    /// Writes the helper's cuBLAS handle to *handle.
    /// @return the cuBLAS status of the operation.
    /// NOTE(review): presumably creates the handle on first use (see
    /// isCublasInitialized) — confirm in the implementation.
    cublasStatus_t getCublasHandle(cublasHandle_t* handle) ;

    /// Presumably releases the cuBLAS handle and marks it uninitialized.
    void clearCublas() ;

    /// @return the last cuBLAS status stored in this helper.
    cublasStatus_t getLastCublasError() const ;

    /// @return the message associated with the last stored cuBLAS status.
    std::string const& getLastCublasErrorMessage() const ;

    /// Records a cuBLAS status and maps it to a vl::Error.
    /// @param status      status returned by a cuBLAS call.
    /// @param description optional context text; may be NULL.
    vl::Error catchCublasError(cublasStatus_t status,
                               char const* description = NULL) ;

    // Use #ifdef (not #if) so this guard matches the one around
    // #include <cudnn.h>: the cuDNN API is declared exactly when its types
    // are available, and an empty `#define ENABLE_CUDNN` no longer
    // triggers a "#if with no expression" error.
#ifdef ENABLE_CUDNN
    // ---------------------------------------------------------------
    // cuDNN support
    // ---------------------------------------------------------------

    /// Writes the helper's cuDNN handle to *handle.
    /// @return the cuDNN status of the operation.
    /// NOTE(review): presumably creates the handle on first use (see
    /// isCudnnInitialized) — confirm in the implementation.
    cudnnStatus_t getCudnnHandle(cudnnHandle_t* handle) ;

    /// Presumably releases the cuDNN handle and marks it uninitialized.
    void clearCudnn() ;

    /// @return whether cuDNN use is currently enabled.
    bool getCudnnEnabled() const ;

    /// Enables or disables cuDNN use.
    /// @param active true to enable cuDNN, false to disable it.
    void setCudnnEnabled(bool active) ;

    /// @return the last cuDNN status stored in this helper.
    cudnnStatus_t getLastCudnnError() const ;

    /// @return the message associated with the last stored cuDNN status.
    std::string const& getLastCudnnErrorMessage() const ;

    /// Records a cuDNN status and maps it to a vl::Error.
    /// @param status      status returned by a cuDNN call.
    /// @param description optional context text; may be NULL.
    vl::Error catchCudnnError(cudnnStatus_t status,
                              char const* description = NULL) ;
#endif

  protected:
    // Lifetime and reset are restricted to vl::Context.
    CudaHelper() ;
    ~CudaHelper() ;
    void clear() ;
    void invalidateGpu() ;
    friend class Context ;

  private:
    // Last CUDA error and its message.
    cudaError_t lastCudaError ;
    std::string lastCudaErrorMessage ;

    // cuBLAS state: handle, whether it has been created, last status.
    cublasHandle_t cublasHandle ;
    bool isCublasInitialized ;
    cublasStatus_t lastCublasError ;
    std::string lastCublasErrorMessage ;

#ifdef ENABLE_CUDNN
    // cuDNN state: last status, handle, whether it has been created,
    // and whether cuDNN use is enabled.
    cudnnStatus_t lastCudnnError ;
    std::string lastCudnnErrorMessage ;
    cudnnHandle_t cudnnHandle ;
    bool isCudnnInitialized ;
    bool cudnnEnabled ;
#endif
  } ;
} // namespace vl
#endif /* defined(__vl__datacu__) */
